{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CatBoost PredictionDiff Feature Importance Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is often useful to know which features contributed most to a model's output. CatBoost models expose this through the `get_feature_importance` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from catboost import CatBoost, Pool, datasets\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "First, let's prepare the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df, test_df = datasets.msrank_10k()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, y_train, group_id_train = np.array(train_df.drop([0, 1], axis=1)), np.array(train_df[0]), np.array(train_df[1])\n",
    "X_test, y_test, group_id_test = np.array(test_df.drop([0, 1], axis=1)), np.array(test_df[0]), np.array(test_df[1])\n",
    "train_pool = Pool(X_train, y_train, group_id=group_id_train)\n",
    "test_pool = Pool(X_test, y_test, group_id=group_id_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's train CatBoost:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = CatBoost({'iterations': 50, 'loss_function': 'YetiRank', 'verbose': False, 'random_seed': 42})\n",
    "model.fit(train_pool);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CatBoost provides several types of feature importance. One of them is `PredictionDiff`: a vector with the contribution of each feature to the difference in `RawFormulaVal` between the two objects of a pair."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's find a pair of objects in the first group of the test pool that our model ranks in the wrong order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# find 1st group\n",
    "group_size = 1\n",
    "while group_id_test[group_size] == group_id_test[0]:\n",
    "    group_size += 1\n",
    "\n",
    "# get predictions\n",
    "target = y_test[:group_size]\n",
    "prediction = model.predict(X_test[: group_size], prediction_type='RawFormulaVal')\n",
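    "# keep (prediction, label, index) triples in a list: in Python 3, zip returns a one-shot iterator\n",
    "prediction = list(zip(prediction, target, range(group_size)))"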
   ]
  },
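  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above relies on the rows of each group being stored contiguously, and it would run past the end of the array if the whole test set were a single group. As a minimal sketch on hypothetical toy data (the array `toy_group_ids` is an illustration, not part of the dataset), the same boundary could be found with NumPy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy group ids, sorted so that each group is contiguous\n",
    "toy_group_ids = np.array([7, 7, 7, 9, 9, 12])\n",
    "\n",
    "# Positions where the group id differs from the id of the first row\n",
    "changes = np.flatnonzero(toy_group_ids != toy_group_ids[0])\n",
    "\n",
    "# Size of the first group; falls back to the full length if there is only one group\n",
    "first_group_size = int(changes[0]) if changes.size else len(toy_group_ids)\n",
    "first_group_size"
   ]
  },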
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0.16605430089485623, 0.0), (-0.011438740254469434, 3.0)]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
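    "# find a wrongly ranked pair of objects: the highest-scored object with label 0\n",
    "# and the lowest-scored object with label 3\n",
    "prediction = list(prediction)  # also works if prediction is a one-shot zip iterator\n",
    "wrong_prediction_idxs = [\n",
    "    max((x for x in prediction if x[1] == 0), key=lambda x: x[0])[2],\n",
    "    min((x for x in prediction if x[1] == 3), key=lambda x: x[0])[2]\n",
    "]\n",
    "test_pool_slice = X_test[wrong_prediction_idxs]\n",
    "\n",
    "list(zip(model.predict(test_pool_slice, prediction_type='RawFormulaVal'), target[wrong_prediction_idxs]))"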
   ]
  },
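  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above picks a single pair. As a minimal sketch on hypothetical toy scores and labels (`toy_scores` and `toy_labels` are illustrations, not model output), every wrongly ordered pair in a group could be listed like this. A pair `(i, j)` counts as misordered when object `i` has the lower label but the higher score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy data: model scores and relevance labels for one group\n",
    "toy_scores = np.array([0.17, -0.01, 0.30, 0.05])\n",
    "toy_labels = np.array([0.0, 3.0, 1.0, 2.0])\n",
    "\n",
    "# misordered[i, j] is True when i has a lower label than j but a higher score\n",
    "misordered = (toy_labels[:, None] < toy_labels[None, :]) & (toy_scores[:, None] > toy_scores[None, :])\n",
    "\n",
    "# List the (i, j) index pairs, widest score gap first\n",
    "pairs = sorted(\n",
    "    (tuple(int(k) for k in p) for p in np.argwhere(misordered)),\n",
    "    key=lambda p: toy_scores[p[1]] - toy_scores[p[0]],\n",
    ")\n",
    "pairs"
   ]
  },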
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's calculate `PredictionDiff` for these two objects and look at the most important features.\n",
    "\n",
    "As the table shows, feature 133 has the largest importance, so changing its value could change the model's ranking of this pair."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Feature Id</th>\n",
       "      <th>Importances</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>133</td>\n",
       "      <td>0.297181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>107</td>\n",
       "      <td>0.141921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>64</td>\n",
       "      <td>0.082890</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>10</td>\n",
       "      <td>0.065537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>14</td>\n",
       "      <td>0.061547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>7</td>\n",
       "      <td>0.061138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>59</td>\n",
       "      <td>0.058217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>122</td>\n",
       "      <td>0.043991</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>54</td>\n",
       "      <td>0.042972</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>49</td>\n",
       "      <td>0.042607</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Feature Id  Importances\n",
       "0        133     0.297181\n",
       "1        107     0.141921\n",
       "2         64     0.082890\n",
       "3         10     0.065537\n",
       "4         14     0.061547\n",
       "5          7     0.061138\n",
       "6         59     0.058217\n",
       "7        122     0.043991\n",
       "8         54     0.042972\n",
       "9         49     0.042607"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prediction_diff = model.get_feature_importance(type='PredictionDiff', data=test_pool_slice, prettified=True)\n",
    "prediction_diff.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "        <script type=\"text/javascript\">\n",
       "        window.PlotlyConfig = {MathJaxConfig: 'local'};\n",
       "        if (window.MathJax) {MathJax.Hub.Config({SVG: {font: \"STIX-Web\"}});}\n",
       "        if (typeof require !== 'undefined') {\n",
       "        require.undef(\"plotly\");\n",
       "        requirejs.config({\n",
       "            paths: {\n",
       "                'plotly': ['https://cdn.plot.ly/plotly-latest.min']\n",
       "            }\n",
       "        });\n",
       "        require(['plotly'], function(Plotly) {\n",
       "            window._Plotly = Plotly;\n",
       "        });\n",
       "        }\n",
       "        </script>\n",
       "        "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "linkText": "Export to plot.ly",
        "plotlyServerURL": "https://plot.ly",
        "showLink": false
       },
       "data": [
        {
         "mode": "lines+markers",
         "name": "Document 0 predictions",
         "type": "scatter",
         "uid": "d37d8d08-b502-4f8e-8959-fcb41bc1b11c",
         "y": [
          0.16605430089485623,
          0.35383003383687106,
          0.41954074760562454,
          0.44601368847586687,
          0.46894550193851814,
          0.47672256738579566,
          0.482251285924186
         ]
        },
        {
         "showlegend": false,
         "type": "scatter",
         "uid": "e724cef7-4fd1-41c8-8de3-2f2e087ab99d",
         "x": [
          0
         ],
         "y": [
          0.16605430089485623
         ]
        },
        {
         "mode": "lines+markers",
         "name": "Document 1 predictions",
         "type": "scatter",
         "uid": "751980a8-e8d6-4d3e-ae6c-e21521d4dda9",
         "y": [
          -0.011438740254469434,
          0.18269462415479784,
          0.23071146299638742,
          0.2670685800041027,
          0.27363042565900536,
          0.2814200161602986,
          0.2857425781413274
         ]
        },
        {
         "showlegend": false,
         "type": "scatter",
         "uid": "cc559291-f889-42eb-940d-c2bbb3aa775f",
         "x": [
          0
         ],
         "y": [
          -0.011438740254469434
         ]
        }
       ],
       "layout": {
        "title": {
         "text": "Prediction variation for feature '133'"
        },
        "xaxis": {
         "showticklabels": false,
         "tickmode": "array",
         "ticktext": [
          "(-inf, 0.5000]",
          "(0.5000, 1.5000]",
          "(1.5000, 2.5000]",
          "(2.5000, 3.5000]",
          "(3.5000, 4.5000]",
          "(4.5000, 5.5000]",
          "(5.5000, +inf)"
         ],
         "tickvals": [
          0,
          1,
          2,
          3,
          4,
          5,
          6
         ],
         "title": {
          "text": "Bins"
         }
        },
        "yaxis": {
         "overlaying": "y2",
         "side": "left",
         "title": {
          "text": "Prediction"
         }
        }
       }
      },
      "text/html": [
       "<div>\n",
       "        \n",
       "        \n",
       "            <div id=\"00b59d7a-da71-44f9-9987-d963819ed213\" class=\"plotly-graph-div\" style=\"height:525px; width:100%;\"></div>\n",
       "            <script type=\"text/javascript\">\n",
       "                require([\"plotly\"], function(Plotly) {\n",
       "                    window.PLOTLYENV=window.PLOTLYENV || {};\n",
       "                    window.PLOTLYENV.BASE_URL='https://plot.ly';\n",
       "                    \n",
       "                if (document.getElementById(\"00b59d7a-da71-44f9-9987-d963819ed213\")) {\n",
       "                    Plotly.newPlot(\n",
       "                        '00b59d7a-da71-44f9-9987-d963819ed213',\n",
       "                        [{\"mode\": \"lines+markers\", \"name\": \"Document 0 predictions\", \"type\": \"scatter\", \"uid\": \"d37d8d08-b502-4f8e-8959-fcb41bc1b11c\", \"y\": [0.16605430089485623, 0.35383003383687106, 0.41954074760562454, 0.44601368847586687, 0.46894550193851814, 0.47672256738579566, 0.482251285924186]}, {\"showlegend\": false, \"type\": \"scatter\", \"uid\": \"e724cef7-4fd1-41c8-8de3-2f2e087ab99d\", \"x\": [0], \"y\": [0.16605430089485623]}, {\"mode\": \"lines+markers\", \"name\": \"Document 1 predictions\", \"type\": \"scatter\", \"uid\": \"751980a8-e8d6-4d3e-ae6c-e21521d4dda9\", \"y\": [-0.011438740254469434, 0.18269462415479784, 0.23071146299638742, 0.2670685800041027, 0.27363042565900536, 0.2814200161602986, 0.2857425781413274]}, {\"showlegend\": false, \"type\": \"scatter\", \"uid\": \"cc559291-f889-42eb-940d-c2bbb3aa775f\", \"x\": [0], \"y\": [-0.011438740254469434]}],\n",
       "                        {\"title\": {\"text\": \"Prediction variation for feature '133'\"}, \"xaxis\": {\"showticklabels\": false, \"tickmode\": \"array\", \"ticktext\": [\"(-inf, 0.5000]\", \"(0.5000, 1.5000]\", \"(1.5000, 2.5000]\", \"(2.5000, 3.5000]\", \"(3.5000, 4.5000]\", \"(4.5000, 5.5000]\", \"(5.5000, +inf)\"], \"tickvals\": [0, 1, 2, 3, 4, 5, 6], \"title\": {\"text\": \"Bins\"}}, \"yaxis\": {\"overlaying\": \"y2\", \"side\": \"left\", \"title\": {\"text\": \"Prediction\"}}},\n",
       "                        {\"responsive\": true, \"plotlyServerURL\": \"https://plot.ly\", \"linkText\": \"Export to plot.ly\", \"showLink\": false}\n",
       "                    ).then(function(){\n",
       "                            \n",
       "var gd = document.getElementById('00b59d7a-da71-44f9-9987-d963819ed213');\n",
       "var x = new MutationObserver(function (mutations, observer) {{\n",
       "        var display = window.getComputedStyle(gd).display;\n",
       "        if (!display || display === 'none') {{\n",
       "            console.log([gd, 'removed!']);\n",
       "            Plotly.purge(gd);\n",
       "            observer.disconnect();\n",
       "        }}\n",
       "}});\n",
       "\n",
       "// Listen for the removal of the full notebook cells\n",
       "var notebookContainer = gd.closest('#notebook-container');\n",
       "if (notebookContainer) {{\n",
       "    x.observe(notebookContainer, {childList: true});\n",
       "}}\n",
       "\n",
       "// Listen for the clearing of the current output cell\n",
       "var outputEl = gd.closest('.output');\n",
       "if (outputEl) {{\n",
       "    x.observe(outputEl, {childList: true});\n",
       "}}\n",
       "\n",
       "                        })\n",
       "                };\n",
       "                });\n",
       "            </script>\n",
       "        </div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "linkText": "Export to plot.ly",
        "plotlyServerURL": "https://plot.ly",
        "showLink": false
       },
       "data": [
        {
         "mode": "lines+markers",
         "name": "Document 0 predictions",
         "type": "scatter",
         "uid": "5478b144-ec39-40fe-a277-61ae6c7e8fec",
         "y": [
          0.09835243711523532,
          0.09924539615545591,
          0.10911821439932881,
          0.10782603649658631,
          0.1321677803350289,
          0.13403279166265059,
          0.14147759038675708,
          0.14316853896424384,
          0.14640684760946618,
          0.14908903728297304,
          0.15151276379035405,
          0.15607157085858223,
          0.1586932564760203,
          0.16068298327545688,
          0.16263739291763543,
          0.16499674576640372,
          0.16605430089485623,
          0.16494264895340638,
          0.16801383173293016,
          0.17009157895592567,
          0.17229698264279386,
          0.17605279953254074,
          0.1777200954259589,
          0.18140453723203487,
          0.18520228732947683,
          0.1875489676204623,
          0.1882314482296358,
          0.1882314482296358
         ]
        },
        {
         "showlegend": false,
         "type": "scatter",
         "uid": "4839aa63-2fd4-4f52-a9ed-d7d8e1f2b533",
         "x": [
          16
         ],
         "y": [
          0.16605430089485623
         ]
        },
        {
         "mode": "lines+markers",
         "name": "Document 1 predictions",
         "type": "scatter",
         "uid": "911c81b5-d2f8-41f6-bf21-cac888204151",
         "y": [
          -0.011438740254469434,
          -0.00797015053303532,
          -0.003799975073160944,
          0.001027881229450369,
          0.007245717874586698,
          0.010992392920974339,
          0.009466717798344085,
          0.010764507981792886,
          0.0064607835070949445,
          0.00850918968102091,
          0.009133195806820631,
          0.012267219309929637,
          0.01567128968539622,
          0.016422341873313498,
          0.020833837511855146,
          0.023193190360623436,
          0.022980971974614108,
          0.024683545273927958,
          0.026493995324526346,
          0.029938614890347916,
          0.03214401857721611,
          0.0357971774978802,
          0.03682167323623228,
          0.03936792541966147,
          0.04769236440538081,
          0.057055281811225,
          0.05641230962880993,
          0.05641230962880993
         ]
        },
        {
         "showlegend": false,
         "type": "scatter",
         "uid": "2602ef85-02a9-4946-84c6-9dcca736bdea",
         "x": [
          0
         ],
         "y": [
          -0.011438740254469434
         ]
        }
       ],
       "layout": {
        "title": {
         "text": "Prediction variation for feature '107'"
        },
        "xaxis": {
         "showticklabels": false,
         "tickmode": "array",
         "ticktext": [
          "(-inf, 8.6009]",
          "(8.6009, 9.1763]",
          "(9.1763, 9.3270]",
          "(9.3270, 9.4949]",
          "(9.4949, 9.6187]",
          "(9.6187, 9.6870]",
          "(9.6870, 10.0166]",
          "(10.0166, 10.4506]",
          "(10.4506, 10.5185]",
          "(10.5185, 10.7824]",
          "(10.7824, 10.8885]",
          "(10.8885, 11.0104]",
          "(11.0104, 11.0448]",
          "(11.0448, 11.3880]",
          "(11.3880, 11.7007]",
          "(11.7007, 11.7811]",
          "(11.7811, 15.3302]",
          "(15.3302, 15.5215]",
          "(15.5215, 16.5321]",
          "(16.5321, 17.0883]",
          "(17.0883, 17.7997]",
          "(17.7997, 17.9252]",
          "(17.9252, 18.3487]",
          "(18.3487, 18.6351]",
          "(18.6351, 18.8729]",
          "(18.8729, 18.9773]",
          "(18.9773, 22.0085]",
          "(22.0085, +inf)"
         ],
         "tickvals": [
          0,
          1,
          2,
          3,
          4,
          5,
          6,
          7,
          8,
          9,
          10,
          11,
          12,
          13,
          14,
          15,
          16,
          17,
          18,
          19,
          20,
          21,
          22,
          23,
          24,
          25,
          26,
          27
         ],
         "title": {
          "text": "Bins"
         }
        },
        "yaxis": {
         "overlaying": "y2",
         "side": "left",
         "title": {
          "text": "Prediction"
         }
        }
       }
      },
      "text/html": [
       "<div>\n",
       "        \n",
       "        \n",
       "            <div id=\"7a5fe817-0a05-4ae9-9de4-f8e32b6df0f2\" class=\"plotly-graph-div\" style=\"height:525px; width:100%;\"></div>\n",
       "            <script type=\"text/javascript\">\n",
       "                require([\"plotly\"], function(Plotly) {\n",
       "                    window.PLOTLYENV=window.PLOTLYENV || {};\n",
       "                    window.PLOTLYENV.BASE_URL='https://plot.ly';\n",
       "                    \n",
       "                if (document.getElementById(\"7a5fe817-0a05-4ae9-9de4-f8e32b6df0f2\")) {\n",
       "                    Plotly.newPlot(\n",
       "                        '7a5fe817-0a05-4ae9-9de4-f8e32b6df0f2',\n",
       "                        [{\"mode\": \"lines+markers\", \"name\": \"Document 0 predictions\", \"type\": \"scatter\", \"uid\": \"5478b144-ec39-40fe-a277-61ae6c7e8fec\", \"y\": [0.09835243711523532, 0.09924539615545591, 0.10911821439932881, 0.10782603649658631, 0.1321677803350289, 0.13403279166265059, 0.14147759038675708, 0.14316853896424384, 0.14640684760946618, 0.14908903728297304, 0.15151276379035405, 0.15607157085858223, 0.1586932564760203, 0.16068298327545688, 0.16263739291763543, 0.16499674576640372, 0.16605430089485623, 0.16494264895340638, 0.16801383173293016, 0.17009157895592567, 0.17229698264279386, 0.17605279953254074, 0.1777200954259589, 0.18140453723203487, 0.18520228732947683, 0.1875489676204623, 0.1882314482296358, 0.1882314482296358]}, {\"showlegend\": false, \"type\": \"scatter\", \"uid\": \"4839aa63-2fd4-4f52-a9ed-d7d8e1f2b533\", \"x\": [16], \"y\": [0.16605430089485623]}, {\"mode\": \"lines+markers\", \"name\": \"Document 1 predictions\", \"type\": \"scatter\", \"uid\": \"911c81b5-d2f8-41f6-bf21-cac888204151\", \"y\": [-0.011438740254469434, -0.00797015053303532, -0.003799975073160944, 0.001027881229450369, 0.007245717874586698, 0.010992392920974339, 0.009466717798344085, 0.010764507981792886, 0.0064607835070949445, 0.00850918968102091, 0.009133195806820631, 0.012267219309929637, 0.01567128968539622, 0.016422341873313498, 0.020833837511855146, 0.023193190360623436, 0.022980971974614108, 0.024683545273927958, 0.026493995324526346, 0.029938614890347916, 0.03214401857721611, 0.0357971774978802, 0.03682167323623228, 0.03936792541966147, 0.04769236440538081, 0.057055281811225, 0.05641230962880993, 0.05641230962880993]}, {\"showlegend\": false, \"type\": \"scatter\", \"uid\": \"2602ef85-02a9-4946-84c6-9dcca736bdea\", \"x\": [0], \"y\": [-0.011438740254469434]}],\n",
       "                        {\"title\": {\"text\": \"Prediction variation for feature '107'\"}, \"xaxis\": {\"showticklabels\": false, \"tickmode\": \"array\", \"ticktext\": [\"(-inf, 8.6009]\", \"(8.6009, 9.1763]\", \"(9.1763, 9.3270]\", \"(9.3270, 9.4949]\", \"(9.4949, 9.6187]\", \"(9.6187, 9.6870]\", \"(9.6870, 10.0166]\", \"(10.0166, 10.4506]\", \"(10.4506, 10.5185]\", \"(10.5185, 10.7824]\", \"(10.7824, 10.8885]\", \"(10.8885, 11.0104]\", \"(11.0104, 11.0448]\", \"(11.0448, 11.3880]\", \"(11.3880, 11.7007]\", \"(11.7007, 11.7811]\", \"(11.7811, 15.3302]\", \"(15.3302, 15.5215]\", \"(15.5215, 16.5321]\", \"(16.5321, 17.0883]\", \"(17.0883, 17.7997]\", \"(17.7997, 17.9252]\", \"(17.9252, 18.3487]\", \"(18.3487, 18.6351]\", \"(18.6351, 18.8729]\", \"(18.8729, 18.9773]\", \"(18.9773, 22.0085]\", \"(22.0085, +inf)\"], \"tickvals\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], \"title\": {\"text\": \"Bins\"}}, \"yaxis\": {\"overlaying\": \"y2\", \"side\": \"left\", \"title\": {\"text\": \"Prediction\"}}},\n",
       "                        {\"responsive\": true, \"plotlyServerURL\": \"https://plot.ly\", \"linkText\": \"Export to plot.ly\", \"showLink\": false}\n",
       "                    ).then(function(){\n",
       "                            \n",
       "var gd = document.getElementById('7a5fe817-0a05-4ae9-9de4-f8e32b6df0f2');\n",
       "var x = new MutationObserver(function (mutations, observer) {{\n",
       "        var display = window.getComputedStyle(gd).display;\n",
       "        if (!display || display === 'none') {{\n",
       "            console.log([gd, 'removed!']);\n",
       "            Plotly.purge(gd);\n",
       "            observer.disconnect();\n",
       "        }}\n",
       "}});\n",
       "\n",
       "// Listen for the removal of the full notebook cells\n",
       "var notebookContainer = gd.closest('#notebook-container');\n",
       "if (notebookContainer) {{\n",
       "    x.observe(notebookContainer, {childList: true});\n",
       "}}\n",
       "\n",
       "// Listen for the clearing of the current output cell\n",
       "var outputEl = gd.closest('.output');\n",
       "if (outputEl) {{\n",
       "    x.observe(outputEl, {childList: true});\n",
       "}}\n",
       "\n",
       "                        })\n",
       "                };\n",
       "                });\n",
       "            </script>\n",
       "        </div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "linkText": "Export to plot.ly",
        "plotlyServerURL": "https://plot.ly",
        "showLink": false
       },
       "data": [
        {
         "mode": "lines+markers",
         "name": "Document 0 predictions",
         "type": "scatter",
         "uid": "a6f98a12-84ee-4f37-a94d-17d051658c14",
         "y": [
          0.08316382062173322,
          0.09091802839928032,
          0.10790403294127493,
          0.13676822275769032,
          0.13998224519831784,
          0.1453622492173347,
          0.14800342241967096,
          0.15018503327410856,
          0.15338816225873206,
          0.15744141596546354,
          0.1643677254669623,
          0.16605430089485623
         ]
        },
        {
         "showlegend": false,
         "type": "scatter",
         "uid": "c66c6700-b7df-40e6-94aa-509058f198ee",
         "x": [
          11
         ],
         "y": [
          0.16605430089485623
         ]
        },
        {
         "mode": "lines+markers",
         "name": "Document 1 predictions",
         "type": "scatter",
         "uid": "7d90569f-4426-4578-b1c3-35ca5db6a26f",
         "y": [
          -0.058213851090722114,
          -0.051432469617469334,
          -0.054311810997726855,
          -0.04727557932292831,
          -0.04265383535961641,
          -0.038764401530749185,
          -0.03516663551621173,
          -0.030576777862756043,
          -0.027743659252993196,
          -0.02387748941899009,
          -0.016387902402024628,
          -0.011438740254469434
         ]
        },
        {
         "showlegend": false,
         "type": "scatter",
         "uid": "938c7084-0990-4b91-8461-902ff357f664",
         "x": [
          11
         ],
         "y": [
          -0.011438740254469434
         ]
        }
       ],
       "layout": {
        "title": {
         "text": "Prediction variation for feature '64'"
        },
        "xaxis": {
         "showticklabels": false,
         "tickmode": "array",
         "ticktext": [
          "(-inf, 0.0077]",
          "(0.0077, 0.0078]",
          "(0.0078, 0.0088]",
          "(0.0088, 0.0103]",
          "(0.0103, 0.0109]",
          "(0.0109, 0.0127]",
          "(0.0127, 0.0128]",
          "(0.0128, 0.0130]",
          "(0.0130, 0.0137]",
          "(0.0137, 0.0139]",
          "(0.0139, 0.0141]",
          "(0.0141, +inf)"
         ],
         "tickvals": [
          0,
          1,
          2,
          3,
          4,
          5,
          6,
          7,
          8,
          9,
          10,
          11
         ],
         "title": {
          "text": "Bins"
         }
        },
        "yaxis": {
         "overlaying": "y2",
         "side": "left",
         "title": {
          "text": "Prediction"
         }
        }
       }
      },
      "text/html": [
       "<div>\n",
       "        \n",
       "        \n",
       "            <div id=\"8e30b0ef-1edf-4903-a8c7-ce666e438a40\" class=\"plotly-graph-div\" style=\"height:525px; width:100%;\"></div>\n",
       "            <script type=\"text/javascript\">\n",
       "                require([\"plotly\"], function(Plotly) {\n",
       "                    window.PLOTLYENV=window.PLOTLYENV || {};\n",
       "                    window.PLOTLYENV.BASE_URL='https://plot.ly';\n",
       "                    \n",
       "                if (document.getElementById(\"8e30b0ef-1edf-4903-a8c7-ce666e438a40\")) {\n",
       "                    Plotly.newPlot(\n",
       "                        '8e30b0ef-1edf-4903-a8c7-ce666e438a40',\n",
       "                        [{\"mode\": \"lines+markers\", \"name\": \"Document 0 predictions\", \"type\": \"scatter\", \"uid\": \"a6f98a12-84ee-4f37-a94d-17d051658c14\", \"y\": [0.08316382062173322, 0.09091802839928032, 0.10790403294127493, 0.13676822275769032, 0.13998224519831784, 0.1453622492173347, 0.14800342241967096, 0.15018503327410856, 0.15338816225873206, 0.15744141596546354, 0.1643677254669623, 0.16605430089485623]}, {\"showlegend\": false, \"type\": \"scatter\", \"uid\": \"c66c6700-b7df-40e6-94aa-509058f198ee\", \"x\": [11], \"y\": [0.16605430089485623]}, {\"mode\": \"lines+markers\", \"name\": \"Document 1 predictions\", \"type\": \"scatter\", \"uid\": \"7d90569f-4426-4578-b1c3-35ca5db6a26f\", \"y\": [-0.058213851090722114, -0.051432469617469334, -0.054311810997726855, -0.04727557932292831, -0.04265383535961641, -0.038764401530749185, -0.03516663551621173, -0.030576777862756043, -0.027743659252993196, -0.02387748941899009, -0.016387902402024628, -0.011438740254469434]}, {\"showlegend\": false, \"type\": \"scatter\", \"uid\": \"938c7084-0990-4b91-8461-902ff357f664\", \"x\": [11], \"y\": [-0.011438740254469434]}],\n",
       "                        {\"title\": {\"text\": \"Prediction variation for feature '64'\"}, \"xaxis\": {\"showticklabels\": false, \"tickmode\": \"array\", \"ticktext\": [\"(-inf, 0.0077]\", \"(0.0077, 0.0078]\", \"(0.0078, 0.0088]\", \"(0.0088, 0.0103]\", \"(0.0103, 0.0109]\", \"(0.0109, 0.0127]\", \"(0.0127, 0.0128]\", \"(0.0128, 0.0130]\", \"(0.0130, 0.0137]\", \"(0.0137, 0.0139]\", \"(0.0139, 0.0141]\", \"(0.0141, +inf)\"], \"tickvals\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], \"title\": {\"text\": \"Bins\"}}, \"yaxis\": {\"overlaying\": \"y2\", \"side\": \"left\", \"title\": {\"text\": \"Prediction\"}}},\n",
       "                        {\"responsive\": true, \"plotlyServerURL\": \"https://plot.ly\", \"linkText\": \"Export to plot.ly\", \"showLink\": false}\n",
       "                    ).then(function(){\n",
       "                            \n",
       "var gd = document.getElementById('8e30b0ef-1edf-4903-a8c7-ce666e438a40');\n",
       "var x = new MutationObserver(function (mutations, observer) {{\n",
       "        var display = window.getComputedStyle(gd).display;\n",
       "        if (!display || display === 'none') {{\n",
       "            console.log([gd, 'removed!']);\n",
       "            Plotly.purge(gd);\n",
       "            observer.disconnect();\n",
       "        }}\n",
       "}});\n",
       "\n",
       "// Listen for the removal of the full notebook cells\n",
       "var notebookContainer = gd.closest('#notebook-container');\n",
       "if (notebookContainer) {{\n",
       "    x.observe(notebookContainer, {childList: true});\n",
       "}}\n",
       "\n",
       "// Listen for the clearing of the current output cell\n",
       "var outputEl = gd.closest('.output');\n",
       "if (outputEl) {{\n",
       "    x.observe(outputEl, {childList: true});\n",
       "}}\n",
       "\n",
       "                        })\n",
       "                };\n",
       "                });\n",
       "            </script>\n",
       "        </div>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.plot_predictions(\n",
    "    data=test_pool_slice,\n",
    "    features_to_change=prediction_diff[\"Feature Id\"][:3],\n",
    "    plot=True);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
